{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Jax to TFLite.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8vD3L4qeREvg"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qLCxmWRyRMZE"
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4k5PoHrgJQOU"
      },
      "source": [
        "# Jax Model Conversion For TFLite\n",
        "## Overview\n",
        "Note: This API is new and only available via `pip install tf-nightly`. It will be available in TensorFlow version 2.7. The API is still experimental and subject to change.\n",
        "\n",
        "This codelab demonstrates how to build a model for MNIST recognition using Jax and how to convert it to TensorFlow Lite. It also demonstrates how to optimize the Jax-converted TFLite model with post-training quantization."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i8cfOBcjSByO"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/jax_conversion/overview\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/jax_conversion/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lq-T8XZMJ-zv"
      },
      "source": [
        "## Prerequisites\n",
        "It's recommended to try this feature with the newest TensorFlow nightly pip build."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EV04hKdrnE4f"
      },
      "source": [
        "!pip install tf-nightly --upgrade\n",
        "!pip install jax --upgrade\n",
        "!pip install jaxlib --upgrade"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QAeY43k9KM55"
      },
      "source": [
        "## Data Preparation\n",
        "Download the MNIST data with the Keras datasets API and preprocess it: scale the pixel values to the range [0, 1] and one-hot encode the labels."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qSOPSZJn1_Tj"
      },
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import functools\n",
        "\n",
        "import time\n",
        "import itertools\n",
        "\n",
        "import numpy.random as npr\n",
        "\n",
        "import jax.numpy as jnp\n",
        "from jax import jit, grad, random\n",
        "# In current Jax releases, optimizers and stax live in jax.example_libraries.\n",
        "from jax.example_libraries import optimizers\n",
        "from jax.example_libraries import stax"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hdJIt3Da2Qn1"
      },
      "source": [
        "def _one_hot(x, k, dtype=np.float32):\n",
        "  \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
        "  return np.array(x[:, None] == np.arange(k), dtype)\n",
        "\n",
        "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
        "train_images, test_images = train_images / 255.0, test_images / 255.0\n",
        "train_images = train_images.astype(np.float32)\n",
        "test_images = test_images.astype(np.float32)\n",
        "\n",
        "train_labels = _one_hot(train_labels, 10)\n",
        "test_labels = _one_hot(test_labels, 10)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0eFhx85YKlEY"
      },
      "source": [
        "## Build the MNIST model with Jax"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mi3TKB9nnQdK"
      },
      "source": [
        "def loss(params, batch):\n",
        "  inputs, targets = batch\n",
        "  preds = predict(params, inputs)\n",
        "  return -jnp.mean(jnp.sum(preds * targets, axis=1))\n",
        "\n",
        "def accuracy(params, batch):\n",
        "  inputs, targets = batch\n",
        "  target_class = jnp.argmax(targets, axis=1)\n",
        "  predicted_class = jnp.argmax(predict(params, inputs), axis=1)\n",
        "  return jnp.mean(predicted_class == target_class)\n",
        "\n",
        "init_random_params, predict = stax.serial(\n",
        "    stax.Flatten,\n",
        "    stax.Dense(1024), stax.Relu,\n",
        "    stax.Dense(1024), stax.Relu,\n",
        "    stax.Dense(10), stax.LogSoftmax)\n",
        "\n",
        "rng = random.PRNGKey(0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bRtnOBdJLd63"
      },
      "source": [
        "## Train & Evaluate the model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SWbYRyj7LYZt"
      },
      "source": [
        "step_size = 0.001\n",
        "num_epochs = 10\n",
        "batch_size = 128\n",
        "momentum_mass = 0.9\n",
        "\n",
        "\n",
        "num_train = train_images.shape[0]\n",
        "num_complete_batches, leftover = divmod(num_train, batch_size)\n",
        "num_batches = num_complete_batches + bool(leftover)\n",
        "\n",
        "def data_stream():\n",
        "  rng = npr.RandomState(0)\n",
        "  while True:\n",
        "    perm = rng.permutation(num_train)\n",
        "    for i in range(num_batches):\n",
        "      batch_idx = perm[i * batch_size:(i + 1) * batch_size]\n",
        "      yield train_images[batch_idx], train_labels[batch_idx]\n",
        "batches = data_stream()\n",
        "\n",
        "opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)\n",
        "\n",
        "@jit\n",
        "def update(i, opt_state, batch):\n",
        "  params = get_params(opt_state)\n",
        "  return opt_update(i, grad(loss)(params, batch), opt_state)\n",
        "\n",
        "_, init_params = init_random_params(rng, (-1, 28 * 28))\n",
        "opt_state = opt_init(init_params)\n",
        "itercount = itertools.count()\n",
        "\n",
        "print(\"\\nStarting training...\")\n",
        "for epoch in range(num_epochs):\n",
        "  start_time = time.time()\n",
        "  for _ in range(num_batches):\n",
        "    opt_state = update(next(itercount), opt_state, next(batches))\n",
        "  epoch_time = time.time() - start_time\n",
        "\n",
        "  params = get_params(opt_state)\n",
        "  train_acc = accuracy(params, (train_images, train_labels))\n",
        "  test_acc = accuracy(params, (test_images, test_labels))\n",
        "  print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n",
        "  print(\"Training set accuracy {}\".format(train_acc))\n",
        "  print(\"Test set accuracy {}\".format(test_acc))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Y1OZBhfQhOj"
      },
      "source": [
        "## Convert to TFLite model\n",
        "The conversion takes three steps:\n",
        "1. Inline the params into the Jax `predict` function with `functools.partial`.\n",
        "2. Build a `jnp.zeros` array to serve as a \"placeholder\" tensor that Jax uses to trace the model.\n",
        "3. Call `experimental_from_jax`:\n",
        "    * The `serving_func` is wrapped in a list.\n",
        "    * Each input is a `(name, array)` tuple. The tuples for one function are wrapped in a list, and that list is itself wrapped in a list."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6pcqKZqdNTmn"
      },
      "source": [
        "serving_func = functools.partial(predict, params)\n",
        "x_input = jnp.zeros((1, 28, 28))\n",
        "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n",
        "    [serving_func], [[('input1', x_input)]])\n",
        "tflite_model = converter.convert()\n",
        "with open('jax_mnist.tflite', 'wb') as f:\n",
        "  f.write(tflite_model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sqEhzaJPSPS1"
      },
      "source": [
        "## Check the Converted TFLite Model\n",
        "Compare the converted model's results with the Jax model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "acj2AYzjSlaY"
      },
      "source": [
        "expected = serving_func(train_images[0:1])\n",
        "\n",
        "# Run the model with TensorFlow Lite\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
        "interpreter.allocate_tensors()\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n",
        "interpreter.invoke()\n",
        "result = interpreter.get_tensor(output_details[0][\"index\"])\n",
        "\n",
        "# Assert that the TFLite model's output matches the Jax model's output\n",
        "# to 5 decimal places.\n",
        "np.testing.assert_almost_equal(expected, result, decimal=5)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qy9Gp4H2SjBL"
      },
      "source": [
        "## Optimize the Model\n",
        "We provide a `representative_dataset` so that the converter can apply post-training quantization to optimize the model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KI0rLV-Meg-2"
      },
      "source": [
        "def representative_dataset():\n",
        "  # Yield 1000 single-image batches for calibrating the quantization ranges.\n",
        "  for i in range(1000):\n",
        "    x = train_images[i:i+1]\n",
        "    yield [x]\n",
        "\n",
        "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n",
        "    [serving_func], [[('x', x_input)]])\n",
        "# Set the quantization options before converting.\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "converter.representative_dataset = representative_dataset\n",
        "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
        "tflite_quant_model = converter.convert()\n",
        "with open('jax_mnist_quant.tflite', 'wb') as f:\n",
        "  f.write(tflite_quant_model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "15xQR3JZS8TV"
      },
      "source": [
        "## Evaluate the Optimized Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X3oOm0OaevD6"
      },
      "source": [
        "expected = serving_func(train_images[0:1])\n",
        "\n",
        "# Run the model with TensorFlow Lite\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)\n",
        "interpreter.allocate_tensors()\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n",
        "interpreter.invoke()\n",
        "result = interpreter.get_tensor(output_details[0][\"index\"])\n",
        "\n",
        "# Quantization introduces rounding error, so compare with a loose tolerance:\n",
        "# decimal=0 passes when the absolute difference is below 1.5.\n",
        "np.testing.assert_almost_equal(expected, result, decimal=0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QqHXCNa3myor"
      },
      "source": [
        "## Compare the Quantized Model Size\n",
        "The quantized model should be roughly four times smaller than the original model, because its weights are stored as 8-bit integers instead of 32-bit floats."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "imFPw007juVG"
      },
      "source": [
        "!du -h jax_mnist.tflite\n",
        "!du -h jax_mnist_quant.tflite"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}